Analyze Path-Integrating Recurrent Neural Networks#

Set Up + Imports#

 In [46]:
import setup

setup.main()
%load_ext autoreload
%autoreload 2
%load_ext jupyter_black

import neurometry.datasets.synthetic as synthetic
import numpy as np
import skdim

from neurometry.dimension.dimension import skdim_dimension_estimation
from neurometry.dimension.dimension import plot_dimension_experiments

import matplotlib.pyplot as plt


import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import geomstats.backend as gs
import torch
Working directory:  /home/facosta/neurometry/neurometry
Directory added to path:  /home/facosta/neurometry
Directory added to path:  /home/facosta/neurometry/neurometry
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The jupyter_black extension is already loaded. To reload it, use:
  %reload_ext jupyter_black

Single-Agent RNN#

Load activations across training epochs#

 In [47]:
import sys

path = os.getcwd() + "/datasets/rnn_grid_cells"
sys.path.append(path)
from neurometry.datasets.load_rnn_grid_cells import load_activations
 In [48]:
epochs = list(range(0, 100, 5))

epochs.append("final")
(
    single_agent_activations,
    single_agent_rate_maps,
    single_agent_state_points,
) = load_activations(epochs, version="single", verbose=True)
Epoch 0 found!!! :D
Epoch 5 found!!! :D
Epoch 10 found!!! :D
Epoch 15 found!!! :D
Epoch 20 found!!! :D
Epoch 25 found!!! :D
Epoch 30 found!!! :D
Epoch 35 found!!! :D
Epoch 40 found!!! :D
Epoch 45 found!!! :D
Epoch 50 found!!! :D
Epoch 55 found!!! :D
Epoch 60 found!!! :D
Epoch 65 found!!! :D
Epoch 70 found!!! :D
Epoch 75 found!!! :D
Epoch 80 found!!! :D
Epoch 85 found!!! :D
Epoch 90 found!!! :D
Epoch 95 found!!! :D
Epoch final found!!! :D
Loaded epochs [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 'final'] of single agent model.
There are 4096 grid cells with 20 x 20 environment resolution, averaged over 50 trajectories.
There are 20000 data points in the 4096-dimensional state space.
There are 400 data points averaged over 50 trajectories in the 4096-dimensional state space.

Plot final activations#

 In [49]:
from neurometry.datasets.load_rnn_grid_cells import plot_rate_map

plot_rate_map(None, 40, single_agent_activations[-1])
../_images/notebooks_07_application_rnns_grid_cells_8_0.png

Load Training Loss#

 In [50]:
model_folder = "Single agent path integration/Seed 1 weight decay 1e-06/"
model_parameters = "steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_1e-06/"

loss_path = (
    "/scratch/facosta/rnn_grid_cells/" + model_folder + model_parameters + "loss.npy"
)

loss = np.load(loss_path)

loss_aggregated = np.mean(loss.reshape(-1, 1000), axis=1)

loss_normalized = (loss_aggregated - np.min(loss_aggregated)) / (
    np.max(loss_aggregated) - np.min(loss_aggregated)
)

plt.plot(np.linspace(0, 100, 100), loss_normalized)
plt.xlabel("Epochs")
plt.ylabel("Normalized trainingloss")
plt.title("Training loss over epochs")
plt.grid()
../_images/notebooks_07_application_rnns_grid_cells_10_0.png

Extract representations from epoch = 0 to epoch = 100 (final)#

 In [51]:
representations = []

for rep in single_agent_rate_maps:
    points = rep.T
    norm_points = points / np.linalg.norm(points, axis=1)[:, None]
    representations.append(norm_points)
 In [52]:
print(
    f"There are {representations[0].shape[0]} points in {representations[0].shape[1]}-dimensional space"
)
There are 400 points in 4096-dimensional space

Compute Persistent Homology using \(\texttt{giotto-tda}\)#

 In [19]:
from gtda.homology import WeakAlphaPersistence, VietorisRipsPersistence
from gtda.diagrams import PairwiseDistance
from gtda.plotting import plot_diagram, plot_heatmap
import neurometry.datasets.synthetic as synthetic

Load synthetic 1-sphere, 2-sphere, and 2-torus neural manifolds

 In [20]:
num_points = representations[0].shape[0]
embedding_dim = representations[0].shape[1]

task_points_circle = synthetic.hypersphere(intrinsic_dim=1, num_points=num_points)

_, circle_points = synthetic.synthetic_neural_manifold(
    points=task_points_circle,
    encoding_dim=embedding_dim,
    nonlinearity="linear",
)

norm_circle_points = circle_points / np.linalg.norm(circle_points, axis=1)[:, None]

task_points_sphere = synthetic.hypersphere(intrinsic_dim=2, num_points=num_points)

_, sphere_points = synthetic.synthetic_neural_manifold(
    points=task_points_sphere,
    encoding_dim=embedding_dim,
    nonlinearity="linear",
)

norm_sphere_points = sphere_points / np.linalg.norm(sphere_points, axis=1)[:, None]

task_points_sphere3 = synthetic.hypersphere(intrinsic_dim=3, num_points=num_points)

_, sphere3_points = synthetic.synthetic_neural_manifold(
    points=task_points_sphere3,
    encoding_dim=embedding_dim,
    nonlinearity="linear",
)

norm_sphere3_points = sphere3_points / np.linalg.norm(sphere3_points, axis=1)[:, None]


torus_task_points = synthetic.hypertorus(intrinsic_dim=2, num_points=num_points)

_, torus_points = synthetic.synthetic_neural_manifold(
    points=torus_task_points,
    encoding_dim=embedding_dim,
    nonlinearity="linear",
)

norm_torus_points = torus_points / np.linalg.norm(torus_points, axis=1)[:, None]


torus3_task_points = synthetic.hypertorus(intrinsic_dim=3, num_points=num_points)

_, torus3_points = synthetic.synthetic_neural_manifold(
    points=torus3_task_points,
    encoding_dim=embedding_dim,
    nonlinearity="linear",
)

norm_torus3_points = torus3_points / np.linalg.norm(torus3_points, axis=1)[:, None]
WARNING! Poisson spikes not generated: mean must be non-negative
WARNING! Poisson spikes not generated: mean must be non-negative
WARNING! Poisson spikes not generated: mean must be non-negative
WARNING! Poisson spikes not generated: mean must be non-negative
WARNING! Poisson spikes not generated: mean must be non-negative
 In [12]:
num_points = 100

embedding_dim = 10

task_points_circle = synthetic.hypersphere(intrinsic_dim=1, num_points=num_points)

noisy_circle_points, circle_points = synthetic.synthetic_neural_manifold(
    points=task_points_circle,
    encoding_dim=embedding_dim,
    nonlinearity="tanh",
    poisson_multiplier=1,
    scales=torch.ones(embedding_dim),
)

print(
    f"There are {circle_points.shape[0]} points in {circle_points.shape[1]}-dimensional space"
)
noise level: 7.07%
There are 100 points in 10-dimensional space
 In [13]:
num_points = 100

embedding_dim = 10

task_points_sphere2 = synthetic.hypersphere(intrinsic_dim=2, num_points=num_points)

noisy_sphere2_points, sphere2_points = synthetic.synthetic_neural_manifold(
    points=task_points_sphere2,
    encoding_dim=embedding_dim,
    nonlinearity="tanh",
    poisson_multiplier=1,
    scales=torch.ones(embedding_dim),
)

print(
    f"There are {sphere2_points.shape[0]} points in {sphere2_points.shape[1]}-dimensional space"
)
noise level: 7.07%
There are 100 points in 10-dimensional space
 In [ ]:
num_points = 100

embedding_dim = 10

task_points_torus2 = synthetic.hypertorus(intrinsic_dim=2, num_points=num_points)

noisy_torus2_points, torus2_points = synthetic.synthetic_neural_manifold(
    points=task_points_torus2,
    encoding_dim=embedding_dim,
    nonlinearity="tanh",
    poisson_multiplier=1,
    scales=torch.ones(embedding_dim),
)

print(
    f"There are {torus2_points.shape[0]} points in {sphere2_points.shape[1]}-dimensional space"
)

Load or Compute Vietoris-Rips persistence diagrams

 In [21]:
homology_dimensions = (
    0,
    1,
    2,
    3,
)
VR = VietorisRipsPersistence(homology_dimensions=homology_dimensions)
 In [22]:
try:
    print("Loading Vietoris-Rips persistence diagrams")
    vr_diagrams = np.load("datasets/rnn_grid_cells/single_agent_vr_pers_diagrams.npy")

except:
    print("Computing Vietoris-Rips persistence diagrams")
    vr_diagrams = VR.fit_transform(
        representations
        + [norm_circle_points]
        + [norm_sphere_points]
        + [norm_torus_points]
        + [norm_sphere3_points]
        + [norm_torus3_points]
    )
    np.save("datasets/rnn_grid_cells/single_agent_vr_pers_diagrams.npy", vr_diagrams)


print(
    f"There are {vr_diagrams.shape[0]} persistence diagrams. Each diagram has {vr_diagrams.shape[1]} features (points)."
)
Loading Vietoris-Rips persistence diagrams
There are 25 persistence diagrams. Each diagram has 1635 features (points).

Each feature is a triple \([b, d, q]\), where \(q\) is the dimension, \(b\) is the birth time, \(d\) is the death time

 In [23]:
fig_torus3 = plot_diagram(
    vr_diagrams[-1],
    plotly_params={"title": "Vietoris-Rips Persistence Diagram, 3-torus"},
)
fig_torus3.update_layout(title="Vietoris-Rips Persistence Diagram, 3-torus")

Note: the Poincaré polynomial of a surface is the generating function of its Betti numbers.

the Poincaré polynomial of an \(n\)-torus is \((1+x)^n\), by the Künneth theorem. The Betti numbers are therefore the binomial coefficients.

Thus for the \(3\)-torus, the non-zero Betti numbers are \((1,3,3,1)\).

 In [13]:
fig_sphere3 = plot_diagram(
    vr_diagrams[-2],
    plotly_params={"title": "Vietoris-Rips Persistence Diagram, 3-sphere"},
)
fig_sphere3.update_layout(title="Vietoris-Rips Persistence Diagram, 3-sphere")
 In [14]:
fig_rep_final = plot_diagram(
    vr_diagrams[-6],
    plotly_params={"title": "Vietoris-Rips Persistence Diagram, final representation"},
)
fig_rep_final.update_layout(
    title="Vietoris-Rips Persistence Diagram, final representation"
)

Compute pairwise topological distance (“landscape”)#

 In [15]:
landscape_PD = PairwiseDistance(metric="landscape", n_jobs=-1)

landscape_distance = landscape_PD.fit_transform(vr_diagrams)
 In [20]:
landscape_distance_to_circle = landscape_distance[-5, :-5]
landscape_distance_to_sphere = landscape_distance[-4, :-5]
landscape_distance_to_torus = landscape_distance[-3, :-5]
landscape_distance_to_sphere3 = landscape_distance[-2, :-5]
landscape_distance_to_torus3 = landscape_distance[-1, :-5]
plt.plot(epochs[:-1], landscape_distance_to_circle, "o-", label="1-sphere")
plt.plot(epochs[:-1], landscape_distance_to_sphere, "o-", label="2-sphere")
plt.plot(epochs[:-1], landscape_distance_to_sphere3, "o-", label="3-sphere")
plt.plot(epochs[:-1], landscape_distance_to_torus, "o-", label="2-torus")
plt.plot(epochs[:-1], landscape_distance_to_torus3, "o-", label="3-torus")
plt.xlabel("Training Epoch")
plt.ylabel("Landscape Distance")
plt.title("Topological Distance of RNN Representation to reference topologies")
plt.grid()
plt.legend();
../_images/notebooks_07_application_rnns_grid_cells_31_0.png
 In [19]:
norm_landscape_distance_to_circle = (
    landscape_distance_to_circle - np.min(landscape_distance_to_circle)
) / (np.max(landscape_distance_to_circle) - np.min(landscape_distance_to_circle))

norm_landscape_distance_to_sphere = (
    landscape_distance_to_sphere - np.min(landscape_distance_to_sphere)
) / (np.max(landscape_distance_to_sphere) - np.min(landscape_distance_to_sphere))

norm_landscape_distance_to_sphere3 = (
    landscape_distance_to_sphere3 - np.min(landscape_distance_to_sphere3)
) / (np.max(landscape_distance_to_sphere3) - np.min(landscape_distance_to_sphere3))

norm_landscape_distance_to_torus = (
    landscape_distance_to_torus - np.min(landscape_distance_to_torus)
) / (np.max(landscape_distance_to_torus) - np.min(landscape_distance_to_torus))

norm_landscape_distance_to_torus3 = (
    landscape_distance_to_torus3 - np.min(landscape_distance_to_torus3)
) / (np.max(landscape_distance_to_torus3) - np.min(landscape_distance_to_torus3))

plt.plot(epochs, norm_landscape_distance_to_circle, "o-", label="1-sphere")
plt.plot(epochs, norm_landscape_distance_to_sphere, "o-", label="2-sphere")
plt.plot(epochs, norm_landscape_distance_to_sphere3, "o-", label="3-sphere")
plt.plot(epochs, norm_landscape_distance_to_torus, "o-", label="2-torus")
plt.plot(epochs, norm_landscape_distance_to_torus3, "o-", label="3-torus")
plt.xlabel("Training Epoch")
plt.ylabel("Normalized Landscape Distance")
plt.title("Topological Distance of RNN Representation to reference topologies")
plt.grid()
plt.legend();
../_images/notebooks_07_application_rnns_grid_cells_32_0.png
 In [95]:
landscape_distance_to_torus_diff = (
    np.diff(landscape_distance_to_torus) / landscape_distance_to_torus[:-1]
)
landscape_distance_to_torus3_diff = (
    np.diff(landscape_distance_to_torus3) / landscape_distance_to_torus3[:-1]
)
landscape_distance_to_sphere_diff = (
    np.diff(landscape_distance_to_sphere) / landscape_distance_to_sphere[:-1]
)
landscape_distance_to_sphere3_diff = (
    np.diff(landscape_distance_to_sphere3) / landscape_distance_to_sphere3[:-1]
)
landscape_distance_to_circle_diff = (
    np.diff(landscape_distance_to_circle) / landscape_distance_to_circle[:-1]
)

loss_diff = np.diff(loss_normalized) / loss_aggregated[:-1]

plt.plot(epochs[1:], landscape_distance_to_torus_diff, "o-", label="2-torus")
plt.plot(epochs[1:], landscape_distance_to_torus3_diff, "o-", label="3-torus")
plt.plot(epochs[1:], landscape_distance_to_sphere_diff, "o-", label="2-sphere")
plt.plot(epochs[1:], landscape_distance_to_sphere3_diff, "o-", label="3-sphere")
plt.plot(epochs[1:], landscape_distance_to_circle_diff, "o-", label="1-sphere")
plt.plot(np.linspace(0, 99, 99), 10 * loss_diff, "o-", label="Training Loss", alpha=0.5)
plt.xlabel("Training Epoch")
plt.ylabel("Time Derivative of Landscape Distance /Loss")
plt.legend()
plt.title("Time Derivative of Landscape Distance / Loss")
plt.grid();
../_images/notebooks_07_application_rnns_grid_cells_33_0.png
 In [15]:
error_normalized = (error - np.min(error)) / (np.max(error) - np.min(error))
loss_normalized = (loss_aggregated - np.min(loss_aggregated)) / (
    np.max(loss_aggregated) - np.min(loss_aggregated)
)
 In [16]:
plt.plot(epochs, error_normalized, "o-", label="Topological Distance")
plt.plot(np.linspace(0, 100, 100), loss_normalized, "o-", alpha=0.5, label="Loss")
plt.xlabel("Training Epoch")
plt.ylabel("Topological Distance of Representation to 2-torus")
plt.title("Topological Distance of RNN Representation to 2-Torus")
plt.grid()
plt.legend();
../_images/notebooks_07_application_rnns_grid_cells_35_0.png
 In [23]:
fig_epoch_0 = plot_diagram(
    vr_diagrams[1],
    homology_dimensions=(0, 1, 2),
    plotly_params={"title": "Vietoris-Rips PD, single-agent RNN, epoch=0"},
)
fig_epoch_0.update_layout(title="Vietoris-Rips PD, single-agent RNN, epoch=0")
 In [24]:
fig_epoch_95 = plot_diagram(
    vr_diagrams[-1],
    homology_dimensions=(0, 1, 2),
    plotly_params={"title": "Vietoris-Rips PD, single-agent RNN, epoch=95"},
)
fig_epoch_95.update_layout(title="Vietoris-Rips PD, single-agent RNN, epoch=95")
 In [30]:
sphere_error_normalized = (sphere_error - np.min(sphere_error)) / (
    np.max(sphere_error) - np.min(sphere_error)
)

plt.plot(epochs, error_normalized, "o-", label="Torus")
plt.plot(epochs, sphere_error_normalized, "o-", label="Sphere")
plt.plot(np.linspace(0, 100, 100), loss_normalized, "o-", alpha=0.5, label="Loss")
plt.xlabel("Training Epoch")
plt.ylabel("Topological Distance/Loss")
plt.legend();
 Out [30]:
<matplotlib.legend.Legend at 0x7f8f4ad0f0d0>
../_images/notebooks_07_application_rnns_grid_cells_38_1.png

Estimate rank of connectivity matrix#

Get final model (epoch \(=100\))

Compare run-times of \(\texttt{giotto-tda}, \texttt{ripser}, \texttt{giotto-ph}\)#

 In [20]:
from gtda.homology import WeakAlphaPersistence, VietorisRipsPersistence
from ripser import ripser
from persim import plot_diagrams
from gph import ripser_parallel

import time


final_representation = representations[-1]


homology_dimensions = (
    0,
    1,
    2,
)
VR = VietorisRipsPersistence(homology_dimensions=homology_dimensions)


gtda_start = time.time()
gtda_vr_diagrams = VR.fit_transform([final_representation])
gtda_end = time.time()
print(
    f"Time to compute Vietoris-Rips persistence diagrams in giotto-tda: {gtda_end - gtda_start:.2f}"
)


ripser_start = time.time()
diagrams = ripser(representations[-1], maxdim=2)["dgms"]
ripser_end = time.time()
print(
    f"Time to compute Vietoris-Rips persistence diagrams in ripser: {ripser_end - ripser_start:.2f}"
)


gph_start = time.time()
gph_vr_diagrams = ripser_parallel(final_representation, maxdim=2, n_threads=-1)
gph_end = time.time()
print(
    f"Time to compute Vietoris-Rips persistence diagrams in giotto-ph: {gph_end - gph_start:.2f} sec"
)
Time to compute Vietoris-Rips persistence diagrams in giotto-tda: 4.770987272262573
Time to compute Vietoris-Rips persistence diagrams in ripser: 15.016701698303223
Time to compute Vietoris-Rips persistence diagrams in giotto-ph: 3.094177722930908
 In [37]:
plot_diagrams(gph_vr_diagrams["dgms"])
../_images/notebooks_07_application_rnns_grid_cells_43_0.png
 In [70]:
diags = ripser_parallel(
    representations[-1], maxdim=2, coeff=2, metric="manhattan", n_threads=-1
)["dgms"]

plot_diagrams(diags)
../_images/notebooks_07_application_rnns_grid_cells_44_0.png
 In [71]:
gph_diagrams = {}

for i in range(len(epochs)):
    gph_diagrams[epochs[i]] = ripser_parallel(
        representations[i], maxdim=2, coeff=2, metric = 'euclidean',n_threads=-1
    )["dgms"]

plot_diagrams(gph_diagrams["final"])

Isolate Grid Cells (cells with high grid score)#

 In [53]:
grid_scores_all_epochs = []
band_scores_all_epochs = []
border_scores_all_epochs = []
for epoch in epochs:
    scores_dir = (
        "/scratch/facosta/rnn_grid_cells/" + model_folder + model_parameters + "scores/"
    )
    grid_scores_all_epochs.append(
        np.load(scores_dir + f"score_60_single_agent_epoch_{epoch}.npy")
    )
    band_scores_all_epochs.append(
        np.load(scores_dir + f"band_scores_single_agent_epoch_{epoch}.npy")
    )
    border_scores_all_epochs.append(
        np.load(scores_dir + f"border_scores_single_agent_epoch_{epoch}.npy")
    )
 In [40]:
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

mean_grid_scores = [np.mean(scores) for scores in grid_scores_all_epochs]

ax[0].plot(epochs, mean_grid_scores, "o-");
../_images/notebooks_07_application_rnns_grid_cells_48_0.png
 In [50]:
# get a sort of grid scores at last epoch
final_epoch_grid_score_sort = np.argsort(grid_scores_all_epochs[-1])

# apply sort to all grid scores
sorted_grid_scores_all_epochs = []

for grid_scores in grid_scores_all_epochs:
    sorted_grid_scores_all_epochs.append(grid_scores[final_epoch_grid_score_sort])
# sorted_grid_scores = [np.mean(score[sort]) for score in grid_scores_all_epochs]

see 40 units with highest grid scores:

 In [68]:
plot_rate_map(final_epoch_grid_score_sort[-40:], None, single_agent_activations[-1])
../_images/notebooks_07_application_rnns_grid_cells_51_0.png

See 40 units with lowest grid score:

 In [69]:
plot_rate_map(final_epoch_grid_score_sort[:40], None, single_agent_activations[-1])
../_images/notebooks_07_application_rnns_grid_cells_53_0.png
 In [117]:
fig, ax = plt.subplots(1, 3, figsize=(15, 5), sharey=True)

ax[0].hist(grid_scores_all_epochs[-1], bins=20)
ax[0].set_xlabel("Grid scores")
ax[0].set_ylabel("Frequency")
ax[0].set_title("Grid scores at last epoch")

ax[1].hist(band_scores_all_epochs[-1], bins=20)
ax[1].set_xlabel("Band scores")
ax[1].set_ylabel("Frequency")
ax[1].set_title("Band scores at last epoch")

ax[2].hist(border_scores_all_epochs[-1], bins=20)
ax[2].set_xlabel("Border scores")
ax[2].set_ylabel("Frequency")
ax[2].set_title("Border scores at last epoch")

plt.tight_layout()
../_images/notebooks_07_application_rnns_grid_cells_54_0.png
 In [111]:
num_top_bottom = 40

lowest_grid_scores_over_time = [
    np.mean(sorted_grid_scores_all_epochs[i][:num_top_bottom])
    for i in range(len(epochs))
]

top_grid_scores_over_time = [
    np.mean(sorted_grid_scores_all_epochs[i][-num_top_bottom:])
    for i in range(len(epochs))
]

average_grid_scores_over_time = [
    np.mean(sorted_grid_scores_all_epochs[i]) for i in range(len(epochs))
]
 In [113]:
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(
    epochs[:-1] + [100],
    lowest_grid_scores_over_time,
    "o-",
    label=f"Mean: bottom {num_top_bottom} grid scores",
)
ax.plot(
    epochs[:-1] + [100],
    average_grid_scores_over_time,
    "o-",
    label="Mean: all grid scores",
)
ax.plot(
    epochs[:-1] + [100],
    top_grid_scores_over_time,
    "o-",
    label=f"Mean: top {num_top_bottom} grid scores",
)

ax.set_xlabel("Training Epoch", fontsize=12)
ax.set_ylabel("Grid Scores", fontsize=12)
ax.set_title("Grid Scores over Training", fontsize=14)
ax.tick_params(axis="both", which="major", labelsize=10)


ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
ax.legend()

plt.tight_layout()

plt.show()

# plt.savefig('grid_scores_over_training.png', dpi=300)
../_images/notebooks_07_application_rnns_grid_cells_56_0.png

Isolate Band Cells (cells with high band score)#

 In [ ]:

Isolate Border cells (cells with high border score)#

 In [ ]:

Compute Spatial Autocorrelation + UMAP#

 In [10]:
from neurometry.datasets.rnn_grid_cells.scores import GridScorer
 In [11]:
from tqdm import tqdm


def compute_spatial_autocorrelation(res, rate_map_single_agent, scorer):
    print("Computing spatial auto-correlation...")
    _, _, _, _, spatial_autocorrelation, _ = zip(
        *[scorer.get_scores(rm.reshape(res, res)) for rm in tqdm(rate_map_single_agent)]
    )

    spatial_autocorrelation = np.array(spatial_autocorrelation)

    return spatial_autocorrelation
 In [19]:
starts = [0.2] * 10
ends = np.linspace(0.4, 1.0, num=10)

box_width = 2.2
box_height = 2.2

res = 20

coord_range = ((-box_width / 2, box_width / 2), (-box_height / 2, box_height / 2))

masks_parameters = zip(starts, ends.tolist())
scorer = GridScorer(res, coord_range, masks_parameters)


# spatial_autocorrelations = []

# for _, epoch in enumerate(epochs):

spatial_autocorrelation = compute_spatial_autocorrelation(
    res, single_agent_rate_maps[-1], scorer
)

print(spatial_autocorrelation.shape)
Computing spatial auto-correlation...

100%|██████████| 4096/4096 [00:35<00:00, 114.23it/s]
 In [32]:
def z_standardize(matrix):
    return (matrix - np.mean(matrix, axis=0)) / np.std(matrix, axis=0)


def vectorized_spatial_autocorrelation_matrix(spatial_autocorrelation):
    num_cells = spatial_autocorrelation.shape[0]
    num_bins = spatial_autocorrelation.shape[1] * spatial_autocorrelation.shape[2]

    spatial_autocorrelation_matrix = np.zeros((num_bins, num_cells))

    for i in range(num_cells):
        vector = spatial_autocorrelation[i].flatten()

        spatial_autocorrelation_matrix[:, i] = vector

    return z_standardize(spatial_autocorrelation_matrix)
 In [33]:
spatial_autocorrelation_matrix = vectorized_spatial_autocorrelation_matrix(
    spatial_autocorrelation
)

print(spatial_autocorrelation_matrix.shape)
 In [43]:
import umap

reducer_2d = umap.UMAP(n_components=2,random_state=42)

embedding = reducer_2d.fit_transform(spatial_autocorrelation_matrix.T)

print(embedding.shape)
 Out [43]:
(4096, 2)
 In [70]:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Plot for Grid Scores
sc1 = axs[0].scatter(
    embedding[:, 0], embedding[:, 1], c=grid_scores_all_epochs[-1], cmap="viridis"
)
axs[0].set_xlabel("UMAP 1")
axs[0].set_ylabel("UMAP 2")
axs[0].set_title("UMAP of Spatial Autocorrelations; Color by Grid Score")
fig.colorbar(sc1, ax=axs[0], orientation="vertical", label="Grid Score")

# Plot for Band Scores
sc2 = axs[1].scatter(
    embedding[:, 0], embedding[:, 1], c=band_scores_all_epochs[-1], cmap="viridis"
)
axs[1].set_xlabel("UMAP 1")
axs[1].set_ylabel("UMAP 2")
axs[1].set_title("UMAP of Spatial Autocorrelations; Color by Band Score")
fig.colorbar(sc2, ax=axs[1], orientation="vertical", label="Band Score")

# Plot for Border Scores
sc3 = axs[2].scatter(
    embedding[:, 0], embedding[:, 1], c=border_scores_all_epochs[-1], cmap="viridis"
)
axs[2].set_xlabel("UMAP 1")
axs[2].set_ylabel("UMAP 2")
axs[2].set_title("UMAP of Spatial Autocorrelations; Color by Border Score")
fig.colorbar(sc3, ax=axs[2], orientation="vertical", label="Border Score")

plt.tight_layout()
../_images/notebooks_07_application_rnns_grid_cells_68_0.png
 In [71]:
reducer_3d = umap.UMAP(n_components=3, random_state=42)

embedding_3d = reducer_3d.fit_transform(spatial_autocorrelation_matrix.T)

print(embedding.shape)
(4096, 2)
 In [72]:
import plotly.graph_objects as go

fig = go.Figure(
    data=[
        go.Scatter3d(
            x=embedding_3d[:, 0],
            y=embedding_3d[:, 1],
            z=embedding_3d[:, 2],
            mode="markers",
            marker=dict(
                size=5,
                color=grid_scores_all_epochs[-1],
                colorscale="Viridis",
                opacity=0.8,
                colorbar=dict(title="Grid Score"),
            ),
        )
    ]
)

fig.update_layout(
    title="3D UMAP Visualization of Spatial Autocorrelations; Color by Grid Score",
    scene=dict(xaxis_title="UMAP 1", yaxis_title="UMAP 2", zaxis_title="UMAP 3"),
    margin=dict(l=0, r=0, b=0, t=30),
)

fig.show()
 In [73]:
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=embedding_3d[:, 0],
            y=embedding_3d[:, 1],
            z=embedding_3d[:, 2],
            mode="markers",
            marker=dict(
                size=5,
                color=band_scores_all_epochs[-1],
                colorscale="Viridis",
                opacity=0.8,
                colorbar=dict(title="Band Score"),
            ),
        )
    ]
)

fig.update_layout(
    title="3D UMAP Visualization of Spatial Autocorrelations; Color by Band Score",
    scene=dict(xaxis_title="UMAP 1", yaxis_title="UMAP 2", zaxis_title="UMAP 3"),
    margin=dict(l=0, r=0, b=0, t=30),
)

fig.show()
 In [74]:
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=embedding_3d[:, 0],
            y=embedding_3d[:, 1],
            z=embedding_3d[:, 2],
            mode="markers",
            marker=dict(
                size=5,
                color=border_scores_all_epochs[-1],
                colorscale="Viridis",
                opacity=0.8,
                colorbar=dict(title="Border Score"),
            ),
        )
    ]
)

fig.update_layout(
    title="3D UMAP Visualization of Spatial Autocorrelations; Color by Border Score",
    scene=dict(xaxis_title="UMAP 1", yaxis_title="UMAP 2", zaxis_title="UMAP 3"),
    margin=dict(l=0, r=0, b=0, t=30),
)

fig.show()
 In [29]:
from neurometry.datasets.load_rnn_grid_cells import plot_rate_map

plot_rate_map([3617, 0, 0, 0, 1], 40, single_agent_activations[-1])
../_images/notebooks_07_application_rnns_grid_cells_73_0.png

Discover “modules” through clustering / dim reduction? (see Gardner Extended Data Fig. 2)#

 In [26]:
# parent_dir = os.getcwd() + "/datasets/rnn_grid_cells/"

parent_dir = "/scratch/facosta/rnn_grid_cells/"

single_model_folder = "Single agent path integration/Seed 1 weight decay 1e-06/"
single_model_parameters = "steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_1e-06/"

saved_model_single_agent = torch.load(
    parent_dir + single_model_folder + single_model_parameters + "final_model.pth"
)


print(f"The model is a dictionary with keys {saved_model_single_agent.keys()}")
The model is a dictionary with keys odict_keys(['encoder.weight', 'RNN.weight_ih_l0', 'RNN.weight_hh_l0', 'decoder.weight'])

Extract the recurrent connectivity matrix:

 In [27]:
W = saved_model_single_agent["RNN.weight_hh_l0"].detach().numpy()
print(f"W has dimensions {W.shape}")
W has dimensions (4096, 4096)

Find singular values of \(W\):

 In [33]:
singular_values = np.linalg.svd(W, compute_uv=False)

Plot singular value spectrum:

 In [57]:
ev_threshold = 0.9

explained_variance = singular_values**2 / np.sum(singular_values**2)

cumulative_explained_variance = np.cumsum(explained_variance)

plt.plot(cumulative_explained_variance, "o-")

plt.xlabel("Number of components")
plt.ylabel("Cumulative explained variance")

plt.yscale("log")
plt.grid()


plt.title("Cumulative explained variance of singular values of RNN weight matrix")

plt.hlines(
    ev_threshold, 0, len(cumulative_explained_variance), linestyles="dashed", colors="r"
)

plt.vlines(
    np.where(cumulative_explained_variance >= ev_threshold)[0][0],
    0,
    ev_threshold,
    linestyles="dashed",
    colors="r",
)

# show number of components to explain 90% of variance on x-axis
plt.text(
    np.where(cumulative_explained_variance >= ev_threshold)[0][0],
    0.1,
    f"Number of components for {100*ev_threshold} variance: {np.where(cumulative_explained_variance >= ev_threshold)[0][0]}",
)


num_components = np.where(cumulative_explained_variance >= ev_threshold)[0][0] + 1

print(
    f"Number of components to explain {100*ev_threshold}% of variance: {num_components}"
)
Number of components to explain 90.0% of variance: 372
../_images/notebooks_07_application_rnns_grid_cells_81_1.png

Dual-Agent RNN#

Load activations across training epochs#

 In [97]:
epochs = list(range(0, 100, 5))
(
    dual_agent_activations,
    dual_agent_rate_maps,
    dual_agent_state_points,
) = load_activations(epochs, version="dual", verbose=True)
Epoch 0 found!!! :D
Epoch 5 found!!! :D
Epoch 10 found!!! :D
Epoch 15 found!!! :D
Epoch 20 found!!! :D
Epoch 25 found!!! :D
Epoch 30 found!!! :D
Epoch 35 found!!! :D
Epoch 40 found!!! :D
Epoch 45 found!!! :D
Epoch 50 found!!! :D
Epoch 55 found!!! :D
Epoch 60 found!!! :D
Epoch 65 found!!! :D
Epoch 70 found!!! :D
Epoch 75 found!!! :D
Epoch 80 found!!! :D
Epoch 85 found!!! :D
Epoch 90 found!!! :D
Epoch 95 found!!! :D
Loaded epochs [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95] of dual agent model.
There are 4096 grid cells with 20 x 20 environment resolution, averaged over 50 trajectories.
There are 20000 data points in the 4096-dimensional state space.
There are 400 data points averaged over 50 trajectories in the 4096-dimensional state space.

Plot final activations#

 In [98]:
plot_rate_map(40, dual_agent_activations[-1])
../_images/notebooks_07_application_rnns_grid_cells_86_0.png

Extract dual agent representations from epoch = 0 to epoch = 95#

 In [99]:
dual_representations = []

for rep in dual_agent_rate_maps:
    points = rep.T
    norm_points = points / np.linalg.norm(points, axis=1)[:, None]
    dual_representations.append(norm_points)

Load training loss#

 In [103]:
model_folder = "Dual agent path integration disjoint PCs/Seed 1 weight decay 1e-06/"
model_parameters = "steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_1e-06/"

loss_path = (
    os.getcwd()
    + "/datasets/rnn_grid_cells/"
    + model_folder
    + model_parameters
    + "loss.npy"
)

loss = np.load(loss_path)

loss_aggregated = np.mean(loss.reshape(-1, 1000), axis=1)

loss_normalized = (loss_aggregated - np.min(loss_aggregated)) / (
    np.max(loss_aggregated) - np.min(loss_aggregated)
)

plt.plot(np.linspace(0, 100, 100), loss_normalized)
plt.xlabel("Epochs")
plt.ylabel("Normalized trainingloss")
plt.title("Training loss over epochs")
plt.grid()
../_images/notebooks_07_application_rnns_grid_cells_90_0.png

Estimate Dimension#

 In [3]:
neural_manifold = rate_maps.T


num_trials = 10
# methods = [method for method in dir(skdim.id) if not method.startswith("_")]
methods = ["MLE", "KNN", "TwoNN", "CorrInt", "lPCA"]

id_estimates = {}
for method_name in methods:
    method = getattr(skdim.id, method_name)()
    estimates = np.zeros(num_trials)
    for trial_idx in range(num_trials):
        method.fit(neural_manifold)
        estimates[trial_idx] = np.mean(method.dimension_)
    id_estimates[method_name] = estimates
 In [6]:
neural_manifold.shape
 Out [6]:
(400, 4096)
 In [18]:
# make side by side plots
fig, axes = plt.subplots(1, 2, figsize=(20, 6))

for i, method in enumerate(methods):
    y = id_estimates[method]
    x = np.repeat(i, len(y))
    axes[0].scatter(x, y, label=method)
    axes[1].scatter(x, y, label=method)

axes[0].set_xticks(range(len(methods)))
axes[0].set_xticklabels(methods)
axes[0].set_xlabel("Dimension Estimation Method")
axes[0].set_ylabel("Values")
axes[0].set_title("Estimates of Intrinsic Dimensionality of Neural Manifold")
axes[0].legend()

axes[1].set_xticks(range(len(methods)))
axes[1].set_xticklabels(methods)
axes[1].set_xlabel("Dimension Estimation Method")
axes[1].set_ylabel("Values")
axes[1].set_ylim([0, 40])
axes[1].set_title("Zoom in: Estimates of Intrinsic Dimensionality of Neural Manifold")
axes[1].legend();
../_images/notebooks_07_application_rnns_grid_cells_94_0.png

estimate extrinsic with PCA, then do nonlinear dim est